package internal

import (
	"bytes"
	"context"
	"fmt"
	"go/ast"
	"go/printer"
	"go/token"
	"go/types"
	"sort"
	"strconv"
	"strings"
	"unicode"
	"unicode/utf8"

	"github.com/xuender/kit/los"
	"golang.org/x/tools/go/ast/astutil"
	"golang.org/x/tools/go/packages"
)

// BuildKey is the type of the context keys NewGen reads build metadata from.
// Using a dedicated key type keeps these keys distinct from context keys of
// any other type.
type BuildKey int

const (
	// KeyVersion is the context key for the generator version string that
	// Frame writes into the generated file header ("// version: ...").
	KeyVersion BuildKey = iota
	// KeyDate is the context key for the build time string that Frame
	// writes into the generated file header ("// build time: ...").
	KeyDate
)

// Gen is the file-bindgen generator state.
//
// It accumulates the body of one generated file together with the imports
// that body needs; Frame turns the result into a complete source file.
type Gen struct {
	// pkg is the package the generated file belongs to; its PkgPath, Name,
	// Fset and Types are consulted during generation.
	pkg         *packages.Package
	// buf collects the generated body written by Pf and writeAST.
	buf         bytes.Buffer
	// imports maps an (unvendored) import path to the identifier the
	// generated code uses for it; entries are added by qualifyImport.
	imports     map[string]ImportInfo
	// anonImports holds the paths Frame emits as blank ("_") imports.
	// NOTE(review): Frame prints the keys with %s, so they appear to be
	// stored already quoted — confirm where this map is filled.
	anonImports map[string]bool
	// values maps expressions to identifier names; nameInFileScope treats
	// these names as taken. NOTE(review): populated outside this section.
	values      map[ast.Expr]string
	// binder is handed to bindPass, which calls its Bind method once per
	// source parameter.
	binder      *Bind
	// version and date are stamped into the generated file header by Frame.
	version     string
	date        string
}

// NewGen returns a generator for pkg.
//
// The generator version and build time that Frame stamps into the generated
// file header are read from ctx under KeyVersion and KeyDate. A missing or
// non-string version falls back to "unknown"; a missing or non-string date
// falls back to "".
func NewGen(ctx context.Context, pkg *packages.Package) *Gen {
	// Checked assertions: a value of the wrong type under one of these keys
	// falls back to the default instead of panicking.
	version, isOk := ctx.Value(KeyVersion).(string)
	if !isOk {
		version = "unknown"
	}

	// The zero value "" is the intended default for the date.
	date, _ := ctx.Value(KeyDate).(string)

	return &Gen{
		pkg:         pkg,
		anonImports: make(map[string]bool),
		imports:     make(map[string]ImportInfo),
		values:      make(map[ast.Expr]string),
		binder:      NewBind(),
		version:     version,
		date:        date,
	}
}

// ImportInfo holds info about an import.
//
// Frame writes the import as `name "path"` when differs is set and as a
// plain `"path"` otherwise.
type ImportInfo struct {
	// name is the identifier that is used in the generated source.
	name string
	// differs is true if the import is given an identifier that does not
	// match the package's identifier.
	differs bool
}

// Pf formats args according to format (as fmt.Sprintf does) and appends the
// result to the generated body.
func (g *Gen) Pf(format string, args ...any) {
	g.buf.WriteString(fmt.Sprintf(format, args...))
}

// writeAST appends node to the generated body. Package references in node
// are first rewritten with rewritePkgRefs, using info for type lookups.
// A printer failure is treated as a programming error and panics.
func (g *Gen) writeAST(info *types.Info, node ast.Node) {
	rewritten := g.rewritePkgRefs(info, node)

	err := printer.Fprint(&g.buf, g.pkg.Fset, rewritten)
	if err != nil {
		panic(err)
	}
}

// rewritePkgRefs rewrites any package references in an AST into references for the
// generated package.
//
// It works on the copy returned by copyAST (defined elsewhere) and returns
// that rewritten copy. Assuming copyAST copies every non-Ident node, the
// caller's AST is left untouched. The rewrite runs in two passes:
//
//  1. Every identifier that refers to a package-level object of another
//     package (through a dot import or a qualified selector such as pkg.Name)
//     is replaced by a selector whose qualifier is the import identifier
//     chosen by g.qualifyImport. This registers every import the snippet
//     needs, so all file-scope names are known before the second pass.
//  2. Objects declared inside node whose names collide with a file-scope
//     name (or with a replacement name already picked in this call) are
//     renamed via disambiguate, and every use of such an object is rewritten
//     to the new name.
//
// info is expected to describe the original (uncopied) node.
// NOTE(review): both passes look the copied nodes up in info; this relies on
// copyAST preserving *ast.Ident identity (and, for info.Scopes, the identity
// of scope-bearing nodes) — confirm against copyAST.
func (g *Gen) rewritePkgRefs(info *types.Info, node ast.Node) ast.Node {
	// Remember the original source range; only objects declared within it
	// are candidates for renaming in the second pass.
	start, end := node.Pos(), node.End()
	node = copyAST(node)
	// First, rewrite all package names. This lets us know all the
	// potentially colliding identifiers.
	node = astutil.Apply(node, func(cur *astutil.Cursor) bool {
		switch node := cur.Node().(type) {
		case *ast.Ident:
			// This is an unqualified identifier (qualified identifiers are peeled off below).
			obj := info.ObjectOf(node)
			if obj == nil {
				// No object recorded for this identifier; leave it as is.
				return false
			}

			if pkg := obj.Pkg(); pkg != nil && obj.Parent() == pkg.Scope() && pkg.Path() != g.pkg.PkgPath {
				// An identifier from either a dot import or read from a different package.
				newPkgID := g.qualifyImport(pkg.Name(), pkg.Path())
				cur.Replace(&ast.SelectorExpr{
					X:   ast.NewIdent(newPkgID),
					Sel: ast.NewIdent(node.Name),
				})

				return false
			}

			return true
		case *ast.SelectorExpr:
			// Only selectors of the form ident.Sel where ident names an
			// imported package are rewritten here; anything else (field or
			// method selectors) is descended into normally.
			pkgIdent, isOk := node.X.(*ast.Ident)
			if !isOk {
				return true
			}

			pkgName, isOk := info.ObjectOf(pkgIdent).(*types.PkgName)
			if !isOk {
				return true
			}
			// This is a qualified identifier. Rewrite and avoid visiting subexpressions.
			imported := pkgName.Imported()
			newPkgID := g.qualifyImport(imported.Name(), imported.Path())
			cur.Replace(&ast.SelectorExpr{
				X:   ast.NewIdent(newPkgID),
				Sel: ast.NewIdent(node.Sel.Name),
			})

			return false
		default:
			return true
		}
	}, nil)
	// Now that we have all the identifiers, rename any variables declared
	// in this scope to not collide.
	newNames := make(map[types.Object]string)
	// inNewNames reports whether str was already chosen as a replacement
	// name during this call.
	inNewNames := func(str string) bool {
		for _, other := range newNames {
			if other == str {
				return true
			}
		}

		return false
	}

	// scopeStack holds the scopes entered during the walk, innermost last;
	// it is pushed in the pre-visit and popped in the post-visit below.
	var scopeStack []*types.Scope

	pkgScope := g.pkg.Types.Scope()

	node = astutil.Apply(node, func(cur *astutil.Cursor) bool {
		if scope := info.Scopes[cur.Node()]; scope != nil {
			scopeStack = append(scopeStack, scope)
		}

		id, isOk := cur.Node().(*ast.Ident)
		if !isOk {
			return true
		}

		obj := info.ObjectOf(id)
		if obj == nil {
			// We rewrote this identifier earlier, so it does not need
			// further rewriting.
			return true
		}

		if n, isOk := newNames[obj]; isOk {
			// We picked a new name for this symbol. Rewrite it.
			cur.Replace(ast.NewIdent(n))

			return false
		}

		if par := obj.Parent(); par == nil || par == pkgScope {
			// Don't rename methods, field names, or top-level identifiers.
			return true
		}
		// Rename any symbols defined within rewritePkgRefs's node that conflict
		// with any symbols in the generated file.
		objName := obj.Name()
		// Skip objects declared outside [start, end) and objects whose name
		// collides with nothing.
		if pos := obj.Pos(); pos < start || end <= pos || !(g.nameInFileScope(objName) || inNewNames(objName)) {
			return true
		}

		newName := disambiguate(objName, func(str string) bool {
			if g.nameInFileScope(str) || inNewNames(str) {
				return true
			}

			// Also reject names already visible from the innermost
			// enclosing scope, so the new name does not shadow them.
			if len(scopeStack) > 0 {
				if _, obj := scopeStack[len(scopeStack)-1].LookupParent(str, token.NoPos); obj != nil {
					return true
				}
			}

			return false
		})

		newNames[obj] = newName
		cur.Replace(ast.NewIdent(newName))

		return false
	}, func(cur *astutil.Cursor) bool {
		if info.Scopes[cur.Node()] != nil {
			// Should be top of stack; pop it.
			scopeStack = scopeStack[:len(scopeStack)-1]
		}

		return true
	})

	return node
}

// Frame bakes the built up source body into an unformatted Go source file.
//
// It returns nil when nothing has been generated. Otherwise the result is,
// in order:
//   - the "Code generated" marker;
//   - the generator version and build time;
//   - a go:generate directive that carries tags, if non-empty;
//   - the !bindgen build constraint;
//   - the package clause;
//   - the named and blank import blocks, each sorted by path;
//   - the accumulated body.
func (g *Gen) Frame(tags string) []byte {
	if g.buf.Len() == 0 {
		return nil
	}

	var buf bytes.Buffer

	if len(tags) > 0 {
		tags = fmt.Sprintf(" gen -tags \"%s\"", tags)
	}

	buf.WriteString("// Code generated by Bindgen. DO NOT EDIT.\n\n")
	// Each header comment must end its own line. Without the newlines the
	// version, build time and go:generate comments ran together on a single
	// line, and `go generate` only recognizes "//go:generate" at the start of
	// a line.
	buf.WriteString("// version:    " + g.version + "\n")
	buf.WriteString("// build time: " + g.date + "\n")
	buf.WriteString("//go:generate go run gitee.com/xuender/bindgen/cmd/bindgen" + tags + "\n")
	// //go:build is the build-constraint syntax since Go 1.17; the legacy
	// +build line is kept for older tooling. The blank line that follows
	// separates the constraints from the package clause, as required.
	buf.WriteString("//go:build !bindgen\n")
	buf.WriteString("//+build !bindgen\n\n")
	buf.WriteString("package " + g.pkg.Name + "\n\n")

	if len(g.imports) > 0 {
		buf.WriteString("import (\n")

		imps := make([]string, 0, len(g.imports))

		for path := range g.imports {
			imps = append(imps, path)
		}

		// Map iteration order is random; sort for deterministic output.
		sort.Strings(imps)

		for _, path := range imps {
			info := g.imports[path]
			if info.differs {
				fmt.Fprintf(&buf, "\t%s %q\n", info.name, path)
			} else {
				fmt.Fprintf(&buf, "\t%q\n", path)
			}
		}

		buf.WriteString(")\n\n")
	}

	if len(g.anonImports) > 0 {
		buf.WriteString("import (\n")

		anonImps := make([]string, 0, len(g.anonImports))

		for path := range g.anonImports {
			anonImps = append(anonImps, path)
		}

		sort.Strings(anonImps)

		// NOTE(review): printed with %s, so the keys are presumably stored
		// already quoted — confirm where anonImports is filled.
		for _, path := range anonImps {
			fmt.Fprintf(&buf, "\t_ %s\n", path)
		}

		buf.WriteString(")\n\n")
	}

	buf.Write(g.buf.Bytes())

	return buf.Bytes()
}

// nameInFileScope reports whether name is already taken at file scope in the
// generated file. A name is taken if it is:
//   - an import identifier;
//   - one of the names recorded in g.values;
//   - any object reachable from the package scope by LookupParent.
func (g *Gen) nameInFileScope(name string) bool {
	for _, imp := range g.imports {
		if imp.name == name {
			return true
		}
	}

	for _, taken := range g.values {
		if taken == name {
			return true
		}
	}

	if _, found := g.pkg.Types.Scope().LookupParent(name, token.NoPos); found != nil {
		return true
	}

	return false
}

// disambiguate picks a unique name, preferring name if it is already unique.
//
// A candidate is usable when it is not a Go keyword and collides reports
// false for it. If name itself is unusable, the smallest integer suffix
// starting at 2 that yields a usable candidate is appended. When name
// already ends in a digit, an underscore separates it from the suffix, so
// "x1" becomes "x1_2" rather than the confusing "x12".
func disambiguate(name string, collides func(string) bool) string {
	usable := func(candidate string) bool {
		return !token.Lookup(candidate).IsKeyword() && !collides(candidate)
	}

	if usable(name) {
		return name
	}

	prefix := name
	if n := len(prefix); n > 0 && prefix[n-1] >= '0' && prefix[n-1] <= '9' {
		prefix += "_"
	}

	for suffix := 2; ; suffix++ {
		candidate := prefix + strconv.Itoa(suffix)
		if usable(candidate) {
			return candidate
		}
	}
}

// qualifyImport returns the identifier the generated code should use to
// refer to the package with the given name and import path.
//
// It returns "" for the generated package itself. Otherwise it registers
// the import (keyed by its path with any vendor/ prefix removed) on first
// use and returns the same identifier on every later call. New identifiers
// never take the name "err" and never clash with other file-scope names.
func (g *Gen) qualifyImport(name, path string) string {
	if path == g.pkg.PkgPath {
		return ""
	}

	const vendorDir = "vendor/"

	key := path
	if at := strings.LastIndex(path, vendorDir); at >= 0 && (at == 0 || path[at-1] == '/') {
		key = path[at+len(vendorDir):]
	}

	if existing, found := g.imports[key]; found {
		return existing.name
	}

	chosen := disambiguate(name, func(candidate string) bool {
		// Keep "err" free for error variables in generated code.
		return candidate == "err" || g.nameInFileScope(candidate)
	})

	g.imports[key] = ImportInfo{name: chosen, differs: chosen != name}

	return chosen
}

// bind emits the code for a binder function named name with signature sig,
// preceded by doc.
//
// If funcOutput rejects sig, bind returns a single error annotated with the
// source position pos. Otherwise it returns nil.
func (g *Gen) bind(pos token.Pos, name string, ignores []string, sig *types.Signature, doc *ast.CommentGroup) []error {
	if err := funcOutput(sig); err != nil {
		wrapped := fmt.Errorf("inject %s: %w", name, err)

		return []error{notePosition(g.pkg.Fset.Position(pos), wrapped)}
	}

	// bindPass runs twice. The first pass discards its output and exists
	// only for its side effects, such as registering imports through
	// qualifyImport. The second pass writes the function. errVar is chosen
	// afresh for each pass, after any names added by the previous one.
	for _, discard := range [...]bool{true, false} {
		bindPass(name, ignores, g.binder, sig, doc, &bindGen{
			g:       g,
			errVar:  disambiguate("err", g.nameInFileScope),
			discard: discard,
		})
	}

	return nil
}

// bindGen is the bind pass generator state.
//
// bind creates a fresh bindGen for each bindPass call. The name lists record
// identifiers already used in the function being generated; nameInInjector
// treats all of them as taken.
type bindGen struct {
	// g is the file-level generator that output and imports go to.
	g *Gen

	// paramNames holds the names chosen for the generated function's
	// parameters, in parameter order (appended by bindPass).
	// localNames and cleanupNames hold further identifiers used in the
	// generated body. NOTE(review): filled outside this section — confirm.
	// errVar is the identifier reserved for the error variable.
	paramNames   []string
	localNames   []string
	cleanupNames []string
	errVar       string

	// discard causes Pf to no-op. Useful to run generation for
	// side-effects like filling in g.imports.
	discard bool
}

// Pf forwards to the file generator's Pf unless this pass discards output.
func (ig *bindGen) Pf(format string, args ...any) {
	if !ig.discard {
		ig.g.Pf(format, args...)
	}
}

// bindPass generates a bind function given the output from analysis.
// The sig passed in should be verified.
//
// The generated function is named name and is preceded by the lines of doc,
// if any. The bind target and sources depend on the kind of signature:
//   - plain function: the first parameter is the target and the remaining
//     parameters are the sources;
//   - method: the receiver is the target and every parameter is a source.
//
// bind.Bind is called once per source. All output goes through igen, so
// nothing is written when igen.discard is set. An error from funcOutput is
// escalated through los.Must0, presumably as a panic. A plain function with
// no parameters panics, because there is nothing to bind into.
func bindPass(name string, ignores []string, bind *Bind, sig *types.Signature, doc *ast.CommentGroup, igen *bindGen) {
	params := sig.Params()
	los.Must0(funcOutput(sig))

	if doc != nil {
		for _, c := range doc.List {
			igen.Pf("%s\n", c.Text)
		}
	}

	var (
		target *types.Var
		start  int
	)

	if recv := sig.Recv(); recv != nil {
		// Every parameter is a source. Render the receiver type with the same
		// qualifier as the parameter types, so a type from the generated
		// package prints unqualified (e.g. "*Service") regardless of how its
		// import path is spelled.
		target, start = recv, 0
		igen.Pf("func (%s %s) %s(", recv.Name(), types.TypeString(recv.Type(), igen.g.qualifyPkg), name)
	} else {
		// Check before touching params.At(0): At does not bounds-check a
		// nil tuple.
		if params.Len() == 0 {
			panic(fmt.Sprintf("bind %s: a function needs at least one parameter to serve as the bind target", name))
		}

		target, start = params.At(0), 1
		igen.Pf("func %s(", name)
	}

	for idx := range params.Len() {
		if idx > 0 {
			igen.Pf(", ")
		}

		param := params.At(idx)

		// Unnamed and blank parameters get a name derived from their type;
		// named ones keep their name unless it is already taken.
		aName := param.Name()
		if aName == "" || aName == "_" {
			aName = typeVariableName(param.Type(), "arg", unexport, igen.nameInInjector)
		} else {
			aName = disambiguate(aName, igen.nameInInjector)
		}

		igen.paramNames = append(igen.paramNames, aName)

		paramType, ellipsis := param.Type(), ""
		if sig.Variadic() && idx == params.Len()-1 {
			// The final parameter of a variadic signature is typed as a
			// slice; print its element type after "...".
			paramType, ellipsis = paramType.(*types.Slice).Elem(), "..."
		}

		igen.Pf("%s %s%s", aName, ellipsis, types.TypeString(paramType, igen.g.qualifyPkg))
	}

	igen.Pf(") {\n")

	for i := start; i < params.Len(); i++ {
		bind.Bind(igen, target, params.At(i), ignores)
	}

	igen.Pf("}\n\n")
}

// getMod renders the type string mod, as produced by types.Type.String
// (e.g. "*example.com/x/y.Foo"), relative to the package named pkg.
//
// A leading "*" is always preserved. The package qualifier is the last
// element of the import path. It is dropped when it equals pkg, so
// "*example.com/x/y.Foo" with pkg "y" becomes "*Foo"; otherwise the short
// qualifier is kept ("*y.Foo"). A type-argument list such as "[T]" is
// carried through unchanged and ignored when locating the qualifier. A
// string without a qualifier (e.g. "int") is returned as is.
func getMod(mod, pkg string) string {
	prefix := ""
	if rest, isPtr := strings.CutPrefix(mod, "*"); isPtr {
		// Strip the pointer marker before slicing so it can never leak into
		// the qualifier comparison (it did for paths without a '/').
		prefix, mod = "*", rest
	}

	head, typeArgs := mod, ""
	if open := strings.IndexByte(mod, '['); open >= 0 {
		head, typeArgs = mod[:open], mod[open:]
	}

	// Keep only the last import-path element plus the type name.
	head = head[strings.LastIndexByte(head, '/')+1:]

	// Type names cannot contain '.', so the last dot separates the
	// qualifier from the type name.
	if dot := strings.LastIndexByte(head, '.'); dot >= 0 && head[:dot] == pkg {
		head = head[dot+1:]
	}

	return prefix + head + typeArgs
}

func typeVariableName(
	varType types.Type,
	defaultName string,
	transform func(string) string,
	collides func(string) bool,
) string {
	if p, isOk := varType.(*types.Pointer); isOk {
		varType = p.Elem()
	}

	var names []string

	switch tmp := varType.(type) {
	case *types.Basic:
		if tmp.Name() != "" {
			names = append(names, tmp.Name())
		}
	case *types.Named:
		obj := tmp.Obj()
		if name := obj.Name(); name != "" {
			names = append(names, name)
		}
		// Provide an alternate name prefixed with the package name if possible.
		// E.g., in case of collisions, we'll use "fooCfg" instead of "cfg2".
		if pkg := obj.Pkg(); pkg != nil && pkg.Name() != "" {
			names = append(names, fmt.Sprintf("%s%s", pkg.Name(), strings.ToTitle(obj.Name())))
		}
	}
	// If we were unable to derive a name, use defaultName.
	if len(names) == 0 {
		names = append(names, defaultName)
	}
	// Transform the name(s).
	for i, name := range names {
		names[i] = transform(name)
	}
	// See if there's an unambiguous name; if so, use it.
	for _, name := range names {
		if !token.Lookup(name).IsKeyword() && !collides(name) {
			return name
		}
	}
	// Otherwise, disambiguate the first name.
	return disambiguate(names[0], collides)
}

// nameInInjector reports whether name collides with any other identifier
// in the current injector. That means the reserved error variable, any
// parameter, local or cleanup name already in use, or any file-scope name
// of the generated file.
func (ig *bindGen) nameInInjector(name string) bool {
	if name == ig.errVar {
		return true
	}

	for _, group := range [][]string{ig.paramNames, ig.localNames, ig.cleanupNames} {
		for _, taken := range group {
			if taken == name {
				return true
			}
		}
	}

	return ig.g.nameInFileScope(name)
}

// qualifyPkg adapts qualifyImport to the types.Qualifier signature, so it
// can be passed to types.TypeString when rendering types into generated code.
func (g *Gen) qualifyPkg(pkg *types.Package) string {
	name, path := pkg.Name(), pkg.Path()

	return g.qualifyImport(name, path)
}

// unexport converts a name that is potentially exported to an unexported name.
//
// The conversion works as follows:
//   - A leading upper-case rune is lowered ("Foo" -> "foo").
//   - A leading run of upper-case runes is lowered as a whole, except for
//     its last rune when that rune starts a lower-case word
//     ("HTTPServer" -> "httpServer", "ID" -> "id").
//   - Names that do not start with an upper-case rune, including "", are
//     returned unchanged.
func unexport(name string) string {
	first, firstSize := utf8.DecodeRuneInString(name)
	if firstSize == 0 || !unicode.IsUpper(first) {
		// "" -> "", foo -> foo
		return name
	}

	rest := name[firstSize:]
	if second, _ := utf8.DecodeRuneInString(rest); !unicode.IsUpper(second) {
		// Foo -> foo
		return string(unicode.ToLower(first)) + rest
	}

	// UPPERWord -> upperWord: lower the leading capitals, but stop before a
	// capital that is immediately followed by a lower-case rune.
	var out strings.Builder
	out.WriteRune(unicode.ToLower(first))

	pos := firstSize
	for {
		cur, curSize := utf8.DecodeRuneInString(name[pos:])
		if curSize == 0 || !unicode.IsUpper(cur) {
			break
		}

		next, nextSize := utf8.DecodeRuneInString(name[pos+curSize:])
		if nextSize > 0 && unicode.IsLower(next) {
			break
		}

		out.WriteRune(unicode.ToLower(cur))
		pos += curSize
	}

	out.WriteString(name[pos:])

	return out.String()
}
